# +~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~ #  
#
#' @title  Train classifier on labelings induced from annotations 
#'          crowd-sourced in the first round of coding
#' @author Hauke Licht
#
# +~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~+~ #

# setup ----

# load packages 
library(readr)
library(data.table)
library(dplyr)
library(tidyr)
library(purrr)
library(caret)
library(glmnet)

# define dir paths (all relative to the current working directory)
base_path <- "."
helper_path <- file.path(base_path, "code", "helpers")

data_path <- file.path(base_path, "data")
fits_path <- file.path(data_path, "fits")
labelings_path <- file.path(data_path, "intermediate", "labelings")
training_path <- file.path(data_path, "intermediate", "training")

# custom functions (presumably defines `multiClassSum()`, which is used as the
# caret summary function further down -- its definition is not in this file)
source(file.path(helper_path, "custom_train_summary.R"))

# class label mappings ----

# display labels keyed by the raw labeling values; the embedded double quotes
# suggest they are meant to be parsed as plot labels (NOTE(review): confirm)
label_abbr <- c(
  "yes-general" = '"General"'
  , "yes-specific" = '"Specific"'
  , "yes-unsure" = '"Ambiguous"'
  , "no" = '"No"'
)

# class names that are syntactically valid R names, keyed by the raw labeling
# values; caret requires valid names as class levels when `classProbs = TRUE`
altlabels <- label_map <- c(
  "yes-general" = "general"
  , "yes-specific" = "specific"
  , "yes-unsure" = "unsure"
  , "no" = "no"
)

# load data ----

## (English) political tweets w/ clusters ----
clustered_tweets <- read_rds(file.path(fits_path, "political_en_tweets_clustered.rds"))

# add item ID (country, user, and status ID) used to match tweets across sources
clustered_tweets$item_id <- with(
  clustered_tweets,
  sprintf("%s_%s_%s", country_iso3c, user_id, status_id)
)

## labeled tweets ----
fp <- file.path(labelings_path, "dawidskene_labelings_sample_1.csv")
labeled_tweets <- read_csv(fp, col_types = "cciccccdidddddddcc")

# remove tweets with invalid labelings
# (`na.rm = TRUE` so the reported count is not NA if any labeling is missing)
(n_ca <- sum(labeled_tweets$labeling == "cannot-answer", na.rm = TRUE))
# rows with missing labeling are dropped as well (filter() drops NA conditions
# anyway; the `!is.na()` just makes that explicit)
labeled_tweets <- filter(labeled_tweets, !is.na(labeling), labeling != "cannot-answer")

count(labeled_tweets, labeling)

# load/create test-training data ----

fp <- file.path(training_path, "training_data_sample_1.rds")

if (file.exists(fp)) {
  # reuse cached splits so that train/test/CV assignments stay fixed
  data <- read_rds(fp)
} else {
  # load tweet features
  # note: need to rely on old tweets (prior to update in May/June 2023) for reproducibility
  tfeats <- read_rds(file.path(data_path, "input", "parl_party_tweets_features_2020-04-24.rds"))
  class(tfeats)  # data.table syntax (`:=`, `.SD`) is used on this object below

  # build training data frame
  key_cols <- c("country_iso3c", "party_id", "user_id", "status_id")
  # loads `training.features`, whose `colname` column lists the feature names
  load(file.path(helper_path, "politicaltweets", "training.features.rda"))
  feature_vars <- training.features$colname

  tfeats[, tweet_type := factor(case_when(is_reply ~ "reply", is_quote ~ "quote", TRUE ~ "tweet"))]
  # tweets posted after 2017-11-07 are coded as having a 280-character limit
  tfeats[, nchar_limit := factor(date > lubridate::ymd("2017-11-07"), c(FALSE, TRUE), c("140", "280"))]

  # keep only key columns and those feature columns available in `tfeats`
  # (e.g. the `ic*` columns are joined in from `clustered_tweets` below)
  tfeats <- tfeats[, .SD, .SDcols = c(key_cols, feature_vars[feature_vars %in% names(tfeats)])]
  glimpse(tfeats)

  # add item ID (same format as `clustered_tweets$item_id`, created when loading data)
  tfeats[, item_id := sprintf("%s_%s_%s", country_iso3c, user_id, status_id)]

  # get features for subset of labeled tweets
  model_matrix <- tfeats[item_id %in% labeled_tweets$item_id, ]

  # join features and `ic*` columns onto the labeled tweets
  # (natural joins on all shared columns; "cannot-answer" already discarded)
  model_matrix <- labeled_tweets %>%
    left_join(model_matrix) %>%
    left_join(select(clustered_tweets, item_id, matches("^ic\\d{1,3}"))) %>%
    filter(!duplicated(item_id))

  # verify
  table(duplicated(model_matrix$item_id))
  table(model_matrix$labeling)

  # determine train and validation set and CV fold indexes
  n_folds <- n_repeats <- 10L

  set.seed(1234)

  # each tweet goes to the test set with probability .2 (drawn per label class)
  data_splits <- model_matrix %>%
    group_by(labeling) %>%
    mutate(test_ = sample(c(TRUE, FALSE), size = n(), replace = TRUE, prob = c(.2, .8))) %>%
    ungroup()

  dim(data_splits)

  # inspect test-validation distribution by label class
  count(data_splits, labeling, test_) %>%
    pivot_wider(names_from = "test_", values_from = "n") %>%
    mutate(total = rowSums(.[-1])) %>%
    mutate(across(2:3, list(prop = ~ . / total)))

  # ensure that no tweet is contained in both test and training sets (expect 0)
  data_splits %>%
    group_by(status_id) %>%
    filter(n_distinct(test_) > 1) %>%
    nrow()

  training_data <- filter(data_splits, !test_)
  dim(training_data)

  # obtain indexes for `n_repeats`-times repeated `n_folds`-fold CV: each
  # repetition assigns every training row a fold uniformly at random (so fold
  # sizes vary); the value indicates holdout-set membership
  reps <- replicate(
    n_repeats
    , sample(seq_len(n_folds), nrow(training_data), replace = TRUE)
    , simplify = FALSE
  )

  # helper functions
  which_is <- function(x, y) which(x == y)
  which_is_not <- function(x, y) which(x != y)

  # determine CV-validation and -train sets

  ## indexes of held-out CV fold samples within repetitions
  rep_cv_holdout <- lapply(reps, function(r) lapply(seq_len(n_folds), which_is, x = r))

  ## indexes of CV-fold training samples within repetitions
  rep_cv_folds <- lapply(reps, function(r) lapply(seq_len(n_folds), which_is_not, x = r))

  # capture feature names and column indexes within `training_data`
  # (`test_` is its last column, so dropping it below does not shift these)
  col_names <- colnames(training_data)
  outcome_idx <- which(col_names %in% "labeling")

  missing_vars <- setdiff(feature_vars, col_names)
  if (length(missing_vars) > 0) {
    stop(
      "feature columns missing from training data: ",
      paste(missing_vars, collapse = ", "),
      call. = FALSE
    )
  }
  # named integer vector: names are feature names, values their column index
  feature_vars <- vapply(feature_vars, function(v) which(v == col_names), NA_integer_)

  fold_names <- sprintf("fold%02d", seq_len(n_folds))
  rep_names <- sprintf("repeat%02d", seq_len(n_repeats))
  resample_names <- sprintf(
    "repeat%02d.fold%02d",
    rep(seq_len(n_repeats), each = n_folds),
    rep(seq_len(n_folds), times = n_repeats)
  )

  data <- list(
    n = list(
      total = nrow(data_splits)
      , training = nrow(training_data)
      , test = nrow(data_splits) - nrow(training_data)
      , cv_folds = setNames(
        lapply(rep_cv_folds, function(r) setNames(lengths(r), fold_names))
        , rep_names
      )
      , cv_holdout = setNames(
        lapply(rep_cv_holdout, function(r) setNames(lengths(r), fold_names))
        , rep_names
      )
    )
    , indexes = list(
      training = which(!data_splits$test_)
      , test = which(data_splits$test_)
      , cv_folds = setNames(purrr::flatten(rep_cv_folds), resample_names)
      , cv_holdout = setNames(purrr::flatten(rep_cv_holdout), resample_names)
    )
    , training = select(training_data, -test_)
    , test = select(filter(data_splits, test_), -test_)
    , outcome_var = setNames(outcome_idx, "labeling")
    , outcomes = list(
      training = filter(data_splits, !test_) %>% {setNames(.$labeling, .$item_id)}
      , test = filter(data_splits, test_) %>% {setNames(.$labeling, .$item_id)}
    )
    , feature_vars = list(
      # `!grepl()` rather than `-grep()`: `x[-integer(0)]` would return an
      # EMPTY vector whenever no derived (ic/pc/e) features are present
      manual = feature_vars[!grepl("^(ic|pc|e)\\d{1,4}$", names(feature_vars))]
      , ics = feature_vars[grepl("^ic\\d{1,3}$", names(feature_vars))]
    )
    # NOTE(review): assumes the first 16 columns of `model_matrix` (which start
    # with the columns of `labeled_tweets`) are metadata -- TODO confirm
    , metadata_vars = setNames(1:16, names(model_matrix)[1:16])
    # NOTE(review): no `cv_holdout_` column is created in this script -- confirm use
    , cv_holdout_var = "cv_holdout_"
    , seed = 1234L
  )

  # write to disk
  write_rds(data, fp)
}

# inspect: label-class counts in the training and in the test split
table(data[["training"]][["labeling"]])
table(data[["test"]][["labeling"]])

# get/set training control parameters ----
ctrl_path <- file.path(training_path, "train_control_cv_folds_sample_1.rds")

if (file.exists(ctrl_path)) {
  # reuse cached control object (incl. resampling indexes and seeds)
  ctrl <- read_rds(ctrl_path)
} else {

  # derive the CV design from the precomputed fold indexes
  n_repeats <- length(data$n$cv_folds)
  n_folds <- length(data$n$cv_folds[[1]])

  ctrl <- trainControl(
    method = "repeatedcv"
    , number = n_folds
    , repeats = n_repeats
    , search = "grid"
    , verboseIter = TRUE
    , returnData = FALSE
    , returnResamp = "all"
    , savePredictions = "final"
    , classProbs = TRUE
    , summaryFunction = multiClassSum
    , selectionFunction = "best"
    , preProcOptions = list()
    , sampling = NULL
    # when you use `index`, the parameters `number` and `repeats` are ignored: https://github.com/topepo/caret/issues/584
    , index = data$indexes$cv_folds
    , indexOut = data$indexes$cv_holdout
    , trim = FALSE
    , allowParallel = TRUE
    , seeds = list()  # placeholder, replaced below
  )

  # manually set seeds: one integer vector per resample (CV folds times
  # repeats) plus a single seed for the final model
  seeds_path <- file.path(training_path, "workerseeds_cv_folds_sample_1.txt")

  if (file.exists(seeds_path)) {

    # one line per list element, integers separated by single spaces
    seeds <- read_lines(seeds_path)
    seeds <- map(strsplit(seeds, " "), as.integer)

  } else {
    set.seed(1234)

    seeds <- c(
      replicate(n_repeats * n_folds, sample.int(1000, 100), simplify = FALSE)
      , sample.int(1000, 1)
    )

    write_lines(map_chr(seeds, paste, collapse = " "), seeds_path)
  }

  # caret expects B + 1 seed elements (B = number of resamples); fail early if
  # a cached seeds file does not match the current CV design
  if (length(seeds) != n_repeats * n_folds + 1L) {
    stop(
      "expected ", n_repeats * n_folds + 1L, " seed elements but found ",
      length(seeds), " (", seeds_path, ")",
      call. = FALSE
    )
  }

  ctrl$seeds <- seeds

  rm(seeds)

  write_rds(ctrl, ctrl_path)
}

# train/load trained model ----

glmnet_fit_path <- file.path(fits_path, "glmnet_classifier_sample_1_tweets.rds")

if (file.exists(glmnet_fit_path)) {

  # read trained model from disk
  glmnet_grid_search <- read_rds(glmnet_fit_path)

  # recover the tuning grid from the first two results columns (alpha, lambda)
  glmnet_grid <- glmnet_grid_search$results %>%
    select(1:2) %>%
    unique()

  n_alpha <- length(unique(glmnet_grid$alpha))
  n_lambda <- unique(count(glmnet_grid, alpha)$n)

} else {

  # construct tuning grid (glmnet is attached in the setup section)

  #' Build a glmnet tuning grid: for each alpha, `length` linearly spaced
  #' lambda values spanning the range of glmnet's own lambda path.
  #'
  #' @param .x data frame of predictors (expanded via model.matrix(), intercept dropped)
  #' @param .y outcome vector
  #' @param .alpha numeric vector of elastic-net mixing values, all in [0, 1]
  #' @param length number of lambda values per alpha
  #' @param ... passed on to glmnet::glmnet() (e.g. `family`)
  #' @return data.frame with columns `alpha` and `lambda`
  seq_lambda <- function(.x, .y, .alpha, length, ...) {
    stopifnot(all(.alpha <= 1), all(.alpha >= 0))
    .x <- model.matrix(~., .x)[, -1, drop = FALSE]

    lambdas <- lapply(.alpha, function(a) {
      fit <- glmnet::glmnet(x = .x, y = .y, alpha = a, ...)
      seq(min(fit$lambda), max(fit$lambda), length.out = length)
    })

    data.frame(
      alpha = rep(.alpha, each = length)
      , lambda = unlist(lambdas)
    )
  }

  n_alpha <- 5
  n_lambda <- 5

  # outcome first, then manual features, then `ic*` features (by column index)
  train_dat <- select(data$training, data$outcome_var, data$feature_vars$manual, data$feature_vars$ics)
  # recode raw labels to syntactically valid class names
  train_dat[[names(data$outcome_var)]] <- factor(
    train_dat[[names(data$outcome_var)]]
    , levels = names(altlabels)
    , labels = altlabels
  )

  glmnet_grid <- seq_lambda(
    .x = train_dat[-1]
    , .y = train_dat[[1]]
    , .alpha = seq(0, 1, length.out = n_alpha)
    , length = n_lambda
    , family = "multinomial"
  )

  # train (resamples run in parallel via the registered foreach backend)
  library(parallel)
  library(doParallel)
  # detectCores() may return NA on some platforms; fall back to one core
  doParallel::registerDoParallel(cores = max(1L, parallel::detectCores(), na.rm = TRUE))

  st <- Sys.time()
  glmnet_grid_search <- caret::train(
    labeling ~ .
    , data = train_dat
    , method = "glmnet"
    , metric = "Mean_BalancedAccuracy"
    , maximize = TRUE
    , preProc = c("center", "scale")
    , trControl = ctrl
    , tuneGrid = glmnet_grid
  )
  glmnet_grid_search$runtime <- Sys.time() - st

  write_rds(glmnet_grid_search, glmnet_fit_path)
}

# inspect CV results ----
these_metrics <- c("Balanced Accuracy", "Sensitivity", "Specificity", "Precision", "F1")

#' Reshape long-format performance metrics into a tidy table.
#'
#' Shared by the CV results and the test-set results below.
#'
#' @param x data frame with columns `name` and `value`; metric names look like
#'   `<Stat>_<Metric>` (e.g. "Mean_BalancedAccuracy") or `<Metric>__<class>`,
#'   optionally suffixed with "SD". Any other columns (e.g. tuning parameters)
#'   are kept as identifiers.
#' @param metrics metric names to keep; spaces are removed before matching
#' @return data frame with the identifier columns, `metric` (name with spaces
#'   re-inserted before capitals), the `mean` (and, if present, `sd`) values,
#'   and `what` (class name, or summary statistic when no class is given)
tidy_metrics <- function(x, metrics) {
  pattern <- paste(gsub(" ", "", metrics), collapse = "|")

  x %>%
    filter(grepl(pattern, name)) %>%
    mutate(
      is_sd = factor(grepl("SD$", name), c(TRUE, FALSE), c("sd", "mean"))
      , name = sub("SD$", "", name)
    ) %>%
    pivot_wider(names_from = "is_sd", values_from = "value") %>%
    separate(name, c("metric", "class"), "__", fill = "right") %>%
    separate(metric, c("stat", "metric"), "_", fill = "left") %>%
    mutate(
      what = ifelse(is.na(class), stat, class)
      , metric = gsub("(?<=\\w)(?=[A-Z])", " ", metric, perl = TRUE)
    ) %>%
    select(-stat, -class)
}

# first two results columns are the tuning parameters (alpha, lambda)
glmnet_grid_search_res <- glmnet_grid_search$results %>%
  pivot_longer(-c(1:2)) %>%
  tidy_metrics(these_metrics)

best_glmnet <- glmnet_grid_search$results[which.max(glmnet_grid_search$results$Mean_BalancedAccuracy), 1:2]

# exact equality is safe here: both sides are copied from the same results table
glmnet_grid_search_res %>%
  filter(alpha == best_glmnet$alpha) %>%
  filter(lambda == best_glmnet$lambda) %>%
  filter(what != "WMean") %>%
  pivot_wider(names_from = "what", values_from = c("mean", "sd")) %>%
  mutate(metric = factor(metric, levels = these_metrics)) %>%
  select(metric, ends_with("Mean"), ends_with("general"), ends_with("specific"), ends_with("unsure"), ends_with("no")) %>%
  arrange(metric)

# evaluate on test set ----
test_set <- data.frame(obs = data$test[[data$outcome_var]])
test_set$obs <- factor(test_set$obs, names(altlabels), altlabels)
test_set$pred <- predict(glmnet_grid_search, data$test)
test_set <- cbind(test_set, predict(glmnet_grid_search, data$test, type = "prob"))

test_res <- multiClassSum(test_set, lev = altlabels)

test_res <- tibble::enframe(test_res) %>%
  tidy_metrics(these_metrics)

pivot_wider(test_res, names_from = "metric", values_from = "mean")
